compute loss only if training and update token metric naming#3293
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughThe Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~15–20 minutes
Pre-merge checks and finishing touches❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)
351-369: Gating tkps tracking onmodel.trainingaligns with intent; consider edge‑cases and minor cleanupConditioning the token accounting on
model.trainingachieves the goal of only tracking tokens during actual training steps and avoids polluting metrics during eval/inference. This should be compatible with the standardTrainer.train()/Trainer.evaluate()flow where the trainer togglesmodel.train()/model.eval().Two minor points to keep in mind:
- If you have any custom code paths that call
compute_lossfor “training-like” work while leavingmodel.training == False(e.g., manual scoring runs), tkps will now be skipped there; worth confirming that you don’t rely on tkps in those flows.- Optional micro‑refactor: you can avoid recomputing the mask sum when updating
self.state.num_tokensby reusingnum_tokens(or its.cpu()copy), e.g.:- if self.args.include_tkps and model.training: - inputs_key = "labels" if "labels" in inputs else "input_ids" - num_tokens = (inputs[inputs_key] != -100).sum() + if self.args.include_tkps and model.training: + inputs_key = "labels" if "labels" in inputs else "input_ids" + num_tokens = (inputs[inputs_key] != -100).sum() + num_tokens_cpu = num_tokens.cpu() @@ - if hasattr(self.state, "num_tokens"): - self.state.num_tokens = ( - self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() - ) - else: - self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() + if hasattr(self.state, "num_tokens"): + self.state.num_tokens = self.state.num_tokens + num_tokens_cpu + else: + self.state.num_tokens = num_tokens_cpu
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/core/trainers/base.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.9.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| tokens_state = json.load(f) | ||
| state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0)) | ||
| state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0)) | ||
| LOG.info(f"Restored total_tokens: {state.total_tokens}") |
There was a problem hiding this comment.
Is this how we should store total_tokens? Are we able to inject it into the trainer state file that gets created in checkpoints?
There was a problem hiding this comment.
ummm TrainerState is a dataclass, it only serializes defined fields
NanoCode012
left a comment
There was a problem hiding this comment.
As mentioned in chat, let's refactor to use tokens/total and tokens/trainable. I believe self.num_tokens is redundant?
| self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu() | ||
| ) | ||
| # Store per-step trainable tokens for throughput calculation | ||
| self.state.tokens["trainable_step"] = trainable_tokens.detach().cpu() |
There was a problem hiding this comment.
Can be removed as is unused (not logged and too similar to others)
| self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 | ||
| ) | ||
| logs["total_tokens"] = int(self.state.total_tokens.item()) | ||
| logs["tokens/total"] = int(self.state.tokens["total"].item()) |
There was a problem hiding this comment.
I think we missed log tokens/trainable
| if tokens and "total" in tokens: | ||
| logs["tokens/total"] = tokens["total"].item() | ||
|
|
||
| if tokens and "trainable" in tokens: | ||
| logs["tokens/trainable"] = tokens["trainable"].item() |
There was a problem hiding this comment.
Is this duplicate log of base.py L651-652?
|
can we add a CI similar to TestResumeLlama that on resume the total tokens is correct? |
|
Okay! |
* upgrade dependencies * don't use reset sessions * downgrade transformers, upgrade other deps * upgrade bnb to 0.49.0 * restore s3 cache * explicit use local files w hub * decompress and strip top level dir * use 2 levels for strip components * try to preserve permissions for symlinks * use updated tar * fix #3293 for distributed * downgrade bnb * fast fail after 4 * fix total tokens device * patch accelerate CP/SP (#3309) --------- Co-authored-by: salman <salman.mohammadi@outlook.com>
|
|




fixes #3291

Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.